# Sampling rules are organised in two levels:
# - sampling rule: a factory for sampling rule states
# - sampling rule state: keeps track of per-run data, e.g. tracking information

import Distributions;

include("../regret.jl");
include("../tracking.jl");
include("../expfam.jl");

```
Uniform sampling
```

# Round-robin (uniform) sampling rule. It carries no data, so one value is
# used as both the factory and the per-run state.
struct RoundRobin end

# Display names of the round-robin rule: a long label and a short tag.
function long(sr::RoundRobin)
    return "Uniform"
end

function abbrev(sr::RoundRobin)
    return "RR"
end

# The rule is stateless, so starting a run hands back the rule itself.
start(sr::RoundRobin, N) = sr

# Cycle through the arms in index order: the remainder of the total count in
# `N` by the number of arms, shifted to 1-based indexing.
function nextsample(sr::RoundRobin, pep, astar, aalt, ξ, N, Xs, rng)
    pulls = sum(N)
    narms = length(N)
    return rem(pulls, narms) + 1
end

```
Oracle sampling
```

# Sampling rule that follows a fixed weight vector `w`; used as both factory
# and state.
#
# The constructor rejects any `w` that is not on the probability simplex
# (all entries ≥ 0 and sum ≈ 1) by throwing an `AssertionError`.
struct FixedWeights
    w
    function FixedWeights(w)
        # Validate with an explicit throw instead of `@assert`: assertions are
        # documented as a debugging aid that may be disabled, which would let
        # invalid weights through. The exception type and message are the
        # ones `@assert` produced, so callers see no difference.
        if !(all(w .≥ 0) && sum(w) ≈ 1)
            throw(AssertionError("$w not in simplex"))
        end
        return new(w)
    end
end

# Display names of the fixed-weights rule: a long label and a short tag.
function long(sr::FixedWeights)
    return "Oracle Weigths"
end

function abbrev(sr::FixedWeights)
    return "opt"
end

# The weights never change, so the rule itself serves as the run state.
start(sr::FixedWeights, N) = sr

# Pick the arm whose count in `N` lags furthest behind its target share
# `sum(N) * w` (first such arm on ties, as `argmin` does).
function nextsample(sr::FixedWeights, pep, astar, aalt, ξ, N, Xs, rng)
    total = sum(N)
    surplus = N .- total .* sr.w
    return argmin(surplus)
end


```
Track-and-Stop
```

# Note: For bounded distributions, this will take forever

# Track-and-Stop sampling rule (factory). `TrackingRule` is the callable that
# is applied to `N` to build the tracker when a run starts.
struct TrackAndStop
    TrackingRule
end

# Display names: the "TaS" prefix followed by the tracking rule's short tag,
# separated by a space (long form) or a dash (short form).
function long(sr::TrackAndStop)
    tag = abbrev(sr.TrackingRule)
    return "TaS " * tag
end

function abbrev(sr::TrackAndStop)
    tag = abbrev(sr.TrackingRule)
    return "TaS-" * tag
end

# Per-run state of Track-and-Stop: a single tracker `t`, built by applying
# `TrackingRule` to `N` and wrapping the result in `ForcedExploration`.
struct TrackAndStopState
    t
    function TrackAndStopState(TrackingRule, N)
        tracker = ForcedExploration(TrackingRule(N))
        return new(tracker)
    end
end

# Build a fresh Track-and-Stop state for a run with initial counts `N`.
start(sr::TrackAndStop, N) = TrackAndStopState(sr.TrackingRule, N)

# Track-and-Stop step: take the second output of `oracle(pep, ξ, Xs)` as the
# target weights and let the tracker choose the arm from them and `N`.
function nextsample(sr::TrackAndStopState, pep, astar, aalt, ξ, N, Xs, rng)
    (_, weights) = oracle(pep, ξ, Xs)
    return track(sr.t, N, weights)
end


```
TopTwo
```

# Top Two sampling rule; used as both factory and state.
struct TopTwo
    β           # compared against a uniform draw to decide leader vs challenger
    leader      # leader name, compared against "EB" and "TS"
    challenger  # challenger name, compared against "TC", "RS" and "TCI..."
end

# Display names, identical for the long and short form:
# "<leader>-<challenger>-<text after the '.' in string(β)>".
function long(sr::TopTwo)
    βdigits = split(string(sr.β), ".")[2]
    return sr.leader * "-" * sr.challenger * "-" * βdigits
end

function abbrev(sr::TopTwo)
    βdigits = split(string(sr.β), ".")[2]
    return sr.leader * "-" * sr.challenger * "-" * βdigits
end

# Top Two keeps no per-run data, so the rule doubles as its own state.
start(sr::TopTwo, N) = sr

# Top Two sampling: choose a leader arm, then return the leader when a uniform
# draw is at most `sr.β`, and a challenger arm otherwise.
#
# Leader (`sr.leader`):
# - "EB": the current answer `astar`.
# - "TS": the argmax of one random draw of the means (Beta draws for Bernoulli
#         arms, Dirichlet-weighted averages of the observations otherwise).
# Challenger (`sr.challenger`):
# - "TC":  `aalt` when the leader is `astar`; otherwise a uniformly random arm,
#          other than the leader, whose entry in `ξ` is at least the leader's.
# - "RS":  redraw the means until the leader is not among the maximisers, then
#          pick uniformly among the maximisers. After more than 1e6 redraws it
#          warns and falls back to a round-robin choice.
# - a name containing "TCI": as "TC" when the leader is not `astar`, but
#          restricted to the candidates with the smallest `N`; when the leader
#          is `astar`, the challenger is taken from `glrt`, called with the
#          text after the first "-" in the name (so the name needs the form
#          "TCI-<suffix>").
#
# `ξ`, `N` and `Xs` are indexed per arm; `rng` is the random source.
# Returns the index of the arm to sample. Throws an `ErrorException` for an
# unrecognised leader, and for an unrecognised challenger whenever a
# challenger is actually needed.
function nextsample(sr::TopTwo, pep, astar, aalt, ξ, N, Xs, rng)
    K = nanswers(pep);
    B = getB(pep);
    _isBer = isBer(pep);

    # Sampling distributions for the means; only the "TS" leader and the "RS"
    # challenger need them.
    if sr.leader == "TS" || sr.challenger == "RS"
        if _isBer
            # One Beta per arm: 1 + (sum of observations), 1 + (N[k] - that sum).
            betas = [Distributions.Beta(1 + sum(Xs[k]), 1 + N[k] - sum(Xs[k])) for k in 1:K]
        else
            # One symmetric Dirichlet per arm over N[k] + 2 weights. Below,
            # weight 1 multiplies B and weights 3:end multiply the observations
            # (this assumes length(Xs[k]) == N[k]); weight 2 adds nothing.
            # NOTE(review): weight 2 presumably stands for mass at 0 - confirm.
            dirichlets = [Distributions.Dirichlet(N[k] + 2, 1) for k in 1:K];
        end
    end

    # Leader
    if sr.leader == "EB"
        a_1 = astar;
    elseif sr.leader == "TS"
        if _isBer
            hξ = [rand(rng, beta) for beta in betas];
        else
            ws = [rand(rng, dir) for dir in dirichlets];
            hξ = [sum(Xs[k] .* ws[k][3:end]) + B * ws[k][1] for k in 1:K];
        end
        a_1 = argmax(hξ);
    else
        # Throw instead of only logging: `a_1` would be undefined, and every
        # path below would fail with an unrelated UndefVarError.
        error("Undefined Top Two leader");
    end

    # Challenger
    u = rand(rng);
    if u <= sr.β
        k = a_1;
    else
        if sr.challenger == "TC"
            if a_1 != astar
                # Sampling uniformly at random among the arms with highest mean, since they have null transportation cost
                # Most of the time, this will be a singleton, hence returning astar
                ks = [a for a in 1:K if ξ[a] >= ξ[a_1] && a != a_1]
                k = ks[rand(rng, 1:length(ks))];
            else
                # When a_1 == astar, the computations from the stopping rule can be used as they rely on the same transportation costs
                k = aalt;
            end
        elseif sr.challenger == "RS"
            # Start with the leader as the only "maximiser" so the loop runs at
            # least once.
            ks = [a_1];
            count = 0;
            while a_1 in ks
                if _isBer
                    hξ = [rand(rng, beta) for beta in betas];
                else
                    ws = [rand(rng, d) for d in dirichlets];
                    hξ = [sum(Xs[a] .* ws[a][3:end]) + B * ws[a][1] for a in 1:K];
                end
                max_hξ = maximum(hξ);
                ks = [a for a in 1:K if hξ[a] == max_hξ];
                count += 1;

                if count > 1e6
                    @warn "RS challenger is taking too much time, hence sample uniformly.";
                    return 1 + (sum(N) % length(N));
                end
            end
            k = ks[rand(rng, 1:length(ks))];
        elseif occursin("TCI", sr.challenger)
            if a_1 != astar
                # Among the arms with higher mean (null transportation cost), it boils down to take the least sampled ones.
                ks = [a for a in 1:K if ξ[a] >= ξ[a_1] && a != a_1];
                min_N = minimum(N[ks]);
                k_min_N = [a for a in ks if N[a] == min_N];
                k = k_min_N[rand(rng, 1:length(k_min_N))];
            else
                # When a_1 == astar, we should compute glr with a bonus/penalization
                _, (k, _), (_, _) = glrt(pep, N, ξ, Xs, split(sr.challenger, "-")[2]);
            end
        else
            # Throw instead of only logging: `k` would be undefined at the
            # `return` below.
            error("Undefined Top Two challenger");
        end
    end

    return k;
end

```
LUCB
```

# LUCB sampling rule; used as both factory and state.
struct LUCB
    type  # bound family, compared against "Ho", "KL" and "Ki"
end

# Display names, identical for the long and short form: "<type>-LUCB".
function long(sr::LUCB)
    return sr.type * "-LUCB"
end

function abbrev(sr::LUCB)
    return sr.type * "-LUCB"
end

# LUCB keeps no per-run data, so the rule doubles as its own state.
start(sr::LUCB, N) = sr

# One LUCB step. Returns `(astar, k, ucb)` where `astar = argmax(ξ)`, `k` is
# the arm other than `astar` with the largest upper confidence bound, and
# `ucb` is that bound minus the lower confidence bound of `astar`.
#
# `sr.type` selects how the bounds are computed:
# - "Ho": ξ[a] ± sqrt(β / (2 N[a]))
# - "KL": `ddn` / `dup` on a `Bernoulli()` object, at level β / N[a]
# - "Ki": `Kinf_dn` / `Kinf_up` on the observations `Xs[a]`, at level β / N[a]
# Any other type throws an `ErrorException`.
#
# Note that this method's signature and return value differ from the other
# `nextsample` methods in this file.
function nextsample(sr::LUCB, pep, ξ, N, Xs, β)
    K = nanswers(pep);
    B = getB(pep);
    _isBer = isBer(pep);
    astar = argmax(ξ);

    # Every bound starts at -Inf so that `astar`, whose slot is never filled,
    # cannot be returned as its own challenger. Starting from zeros only
    # guaranteed this in the "Ho" branch: in the other branches `astar` won
    # the argmax whenever the other bounds were all ≤ 0.
    UCBs = fill(-Inf, K);
    if sr.type == "Ho"
        UCBs = ξ .+ sqrt.(β ./ (2 * N));
        UCBs[astar] = -Inf;
        LCB = ξ[astar] - sqrt(β / (2 * N[astar]));
    elseif sr.type == "KL"
        ber = Bernoulli()
        LCB = ddn(ber, ξ[astar], β / N[astar]);
        for a in 1:K
            if a != astar
                UCBs[a] = dup(ber, ξ[a], β / N[a]);
            end
        end
    elseif sr.type == "Ki"
        LCB = Kinf_dn(_isBer, Xs[astar], ξ[astar], β / N[astar], B);
        for a in 1:K
            if a != astar
                UCBs[a] = Kinf_up(_isBer, Xs[a], ξ[a], β / N[a], B);
            end
        end
    else
        # Throw instead of only logging: continuing returned a meaningless
        # result built from the placeholder bounds.
        error("Unrecognized type of LUCB")
    end

    # Best challenger to astar
    k = argmax(UCBs);

    # Compute gap
    ucb = UCBs[k] - LCB;
    return astar, k, ucb;
end


```
NPGame
```

# NPGame sampling rule (factory).
struct NPGame
    TrackingRule  # callable applied to `N` to build the tracker
    with_fe       # when true, the tracker is wrapped in ForcedExploration
end

# Display names: an "F-" prefix when forced exploration is on, then "NPGame"
# and the tracking rule's short tag, joined by a space (long) or dash (short).
function long(sr::NPGame)
    prefix = sr.with_fe ? "F-" : ""
    return prefix * "NPGame " * abbrev(sr.TrackingRule)
end

function abbrev(sr::NPGame)
    prefix = sr.with_fe ? "F-" : ""
    return prefix * "NPGame-" * abbrev(sr.TrackingRule)
end

# Per-run state of NPGame.
struct NPGameState
    h        # one online learner in total (an AdaHedge over length(N) arms)
    t        # tracker, optionally wrapped in ForcedExploration
    with_fe  # whether forced exploration is enabled
    function NPGameState(TrackingRule, N, with_fe)
        learner = AdaHedge(length(N))
        tracker = if with_fe
            ForcedExploration(TrackingRule(N))
        else
            TrackingRule(N)
        end
        return new(learner, tracker, with_fe)
    end
end

# Build a fresh NPGame state for a run with initial counts `N`.
start(sr::NPGame, N) = NPGameState(sr.TrackingRule, N, sr.with_fe)

# NPGame step: ask the online learner for weights, compute the best response
# to them via `glrt`, feed the learner the negated gain vector, and let the
# tracker choose the arm from `N` and the weights.
function nextsample(sr::NPGameState, pep, astar, aalt, ξ, N, Xs, rng)
    bound = getB(pep)
    bernoulli = isBer(pep)

    # weights proposed by the learner
    w = act(sr.h)

    # best response of the λ-player to w
    _, (challenger, λs), (_, _) = glrt(pep, w, ξ, Xs)

    # optimistic gradient: non-zero only at astar and the challenger, each
    # entry floored at log(sum(N)) / N[arm]
    gains = zeros(length(ξ))
    gain_star = Kinf_emp(bernoulli, Xs[astar], ξ[astar], λs[astar], bound, false)
    gains[astar] = max(gain_star, log(sum(N)) / N[astar])
    gain_alt = Kinf_emp(bernoulli, Xs[challenger], ξ[challenger], λs[challenger], bound, true)
    gains[challenger] = max(gain_alt, log(sum(N)) / N[challenger])
    incur!(sr.h, -gains)

    # tracking
    return track(sr.t, N, w)
end
